from __future__ import division

# If ROS's Python 2 packages interfere with imports, the following removes them from sys.path.
import sys

# ROS kinetic injects its Python 2.7 packages into sys.path, which shadows
# the Python 3 packages this script needs; drop that entry if present.
ros_pth = '/opt/ros/kinetic/lib/python2.7/dist-packages'
if ros_pth in sys.path:
    # Reuse the variable instead of repeating the literal (original code
    # duplicated the string, inviting a silent mismatch if one copy changed).
    sys.path.remove(ros_pth)

import os
import torch as t
from src.config import opt
from src.head_detector_vgg16 import Head_Detector_VGG16
from trainer import Head_Detector_Trainer
from PIL import Image
import numpy as np
from data.dataset import preprocess
import matplotlib.pyplot as plt
import src.array_tool as at
from src.vis_tool import visdom_bbox
import argparse
import src.utils as utils
from src.config import opt
import time

SAVE_FLAG = 0
THRESH = 0.01
IM_RESIZE = False


# def read_img(path):
#     f = Image.open(path)
#     if IM_RESIZE:
#         f = f.resize((640, 480), Image.ANTIALIAS)
#
#     f.convert('RGB')
#     img_raw = np.asarray(f, dtype=np.uint8)
#     img_raw_final = img_raw.copy()
#     img = np.asarray(f, dtype=np.float32)
#     _, H, W = img.shape
#     img = img.transpose((2, 0, 1))
#     img = preprocess(img)
#     _, o_H, o_W = img.shape
#     scale = o_H / H
#     return img, img_raw_final, scale


def read_img(path):
    """Load an image from *path* and preprocess it for the detector.

    Returns:
        img: float32 CHW array after ``preprocess`` (network input).
        img_raw: PIL image resized to the preprocessed resolution, used
            later for drawing bounding boxes / display.
        scale: ratio of preprocessed height to original height; passed to
            the detector so it can map boxes back to image coordinates.
    """
    f = Image.open(path)
    if IM_RESIZE:
        f = f.resize((640, 480), Image.ANTIALIAS)

    # BUG FIX: PIL's convert() returns a new image and does not modify in
    # place; the original code discarded the result, so grayscale/RGBA
    # inputs reached np.asarray with the wrong number of channels.
    f = f.convert('RGB')
    img_raw = f.copy()
    img = np.asarray(f, dtype=np.float32)
    # BUG FIX: at this point the array is (H, W, C). The original
    # `_, H, W = img.shape` bound H to the width and W to the channel
    # count, which made `scale` wrong for any non-square image.
    H, W, _ = img.shape
    img = img.transpose((2, 0, 1))  # HWC -> CHW for the network
    img = preprocess(img)
    _, o_H, o_W = img.shape
    scale = o_H / H
    # Keep the display image at the same resolution as the network input
    # so predicted boxes line up when drawn on it.
    img_raw = img_raw.resize((o_W, o_H), Image.ANTIALIAS)
    return img, img_raw, scale


def detect(img_path, model_path):
    """Detect heads in a single image and display or save the result.

    Loads the model from *model_path*, runs prediction on *img_path*,
    draws the predicted boxes, then either saves the figure (SAVE_FLAG == 1)
    or shows it interactively.
    """
    file_id = utils.get_file_id(img_path)
    img, img_raw, scale = read_img(img_path)

    # Build the detector and restore trained weights.
    head_detector = Head_Detector_VGG16(ratios=[1], anchor_scales=[2, 4])
    trainer = Head_Detector_Trainer(head_detector).cuda()
    trainer.load(model_path)

    # Shape the image as a single-element CUDA batch.
    batch = at.totensor(img)[None, :, :, :].cuda().float()

    begin = time.time()
    pred_bboxes_, _ = head_detector.predict(batch, scale, mode='evaluate', thresh=THRESH)
    elapsed = time.time() - begin
    print("[INFO] Head detection over. Time taken: {:.4f} s".format(elapsed))

    # Draw every predicted box (ymin, xmin, ymax, xmax) on the raw image.
    for box in pred_bboxes_:
        ymin, xmin, ymax, xmax = box
        utils.draw_bounding_box_on_image(img_raw, ymin, xmin, ymax, xmax)

    plt.axis('off')
    plt.imshow(img_raw)
    if SAVE_FLAG == 1:
        out_name = file_id + '_tested_' + str(THRESH) + '.png'
        plt.savefig(os.path.join(opt.test_output_path, out_name),
                    bbox_inches='tight', pad_inches=0)
    else:
        plt.show()


def detect_path(path_img, model_path):
    """Run head detection on every image file in directory *path_img*.

    Loads the model once, then iterates the directory in sorted order.
    For each image, predicted boxes are drawn and the result is either
    saved to opt.test_output_path (SAVE_FLAG == 1) or flashed briefly in
    an interactive matplotlib window.
    """
    plt.ion()
    plt.figure(1)

    # Load the detector once, outside the per-image loop.
    head_detector = Head_Detector_VGG16(ratios=[1], anchor_scales=[2, 4])
    trainer = Head_Detector_Trainer(head_detector).cuda()
    trainer.load(model_path)

    for file in sorted(os.listdir(path_img)):
        img_path = os.path.join(path_img, file)
        # Robustness: skip sub-directories and other non-file entries
        # instead of crashing inside Image.open.
        if not os.path.isfile(img_path):
            continue
        print(f'[INFO] Detecting {file} ...')
        file_id = utils.get_file_id(img_path)
        img, img_raw, scale = read_img(img_path)

        # Single-image CUDA batch for the network.
        img = at.totensor(img)
        img = img[None, :, :, :]
        img = img.cuda().float()

        st = time.time()
        pred_bboxes_, _ = head_detector.predict(img, scale, mode='evaluate', thresh=THRESH)
        et = time.time()
        tt = et - st
        print("[INFO] Head detection over. Time taken: {:.4f} s".format(tt))

        for i in range(pred_bboxes_.shape[0]):
            ymin, xmin, ymax, xmax = pred_bboxes_[i, :]
            utils.draw_bounding_box_on_image(img_raw, ymin, xmin, ymax, xmax)

        plt.axis('off')
        plt.imshow(img_raw)
        if SAVE_FLAG == 1:
            plt.savefig(os.path.join(opt.test_output_path, file_id + '_tested_' + str(THRESH) + '.png'),
                        bbox_inches='tight', pad_inches=0)
        else:
            # Interactive mode: show each result briefly, then clear the
            # figure for the next image.
            plt.pause(0.2)
            plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE: --img_path is only consumed by the single-image `detect` call
    # below, which is currently disabled in favor of directory mode.
    parser.add_argument("--img_path", type=str, help="test image path", default="./image.jpg")
    parser.add_argument("--model_path", type=str, default='./checkpoints/head_detector_final')
    args = parser.parse_args()
    # Single-image mode:
    # detect(args.img_path, args.model_path)

    # Directory mode: run over the whole brainwash test folder.
    # os.path.join avoids double separators if the root already ends in '/'.
    path = os.path.join(opt.brainwash_dataset_root_path, 'brainwash_10_27_2014_images')
    detect_path(path, args.model_path)
